{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VFidSWlB6CIX"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Lb4FT9_yTz6Z"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "from colabtools import adhoc_import\n",
        "import importlib\n",
        "from simulation_research.diffusion import ode_datasets\n",
        "from simulation_research.diffusion import diffusion_unet\n",
        "from simulation_research.diffusion import samplers\n",
        "importlib.reload(ode_datasets)\n",
        "importlib.reload(diffusion_unet)\n",
        "importlib.reload(samplers)\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib import rc\n",
        "rc('animation', html='jshtml')\n",
        "import jax.numpy as jnp\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Joy4WVQ5Bqrw"
      },
      "outputs": [],
      "source": [
        "from jax import devices,device_count\n",
        "device_count()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iyArKlOsUVs5"
      },
      "outputs": [],
      "source": [
        "tf.executing_eagerly()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i4P7SzmOsEvF"
      },
      "source": [
        "# Generate the Trajectories"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eDIphxAwsKUE"
      },
      "source": [
        "## N-Link Pendulum"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jgkNTEf-fhAR"
      },
      "outputs": [],
      "source": [
        "dt = .1\n",
        "ds = ode_datasets.NPendulum(N=2000,n=1,dt=dt)\n",
        "thetas,vs = ode_datasets.unpack(ds.Zs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N8HsUM03yuk8"
      },
      "outputs": [],
      "source": [
        "# for i in range(20):\n",
        "#   fig = plt.figure()\n",
        "#   ax = fig.add_subplot(1, 1, 1)\n",
        "#   line1, = ax.plot(ds.T_long,thetas[i,:,0])\n",
        "#   line2, = ax.plot(ds.T_long,thetas[i,:,1])#\n",
        "#   #line2, = ax.plot(ds.T_long,jnp.cos(thetas[i,:,1])+jnp.cos(thetas[i,:,0]))\n",
        "#   plt.xlabel('Time t')\n",
        "#   plt.ylabel(r'State')\n",
        "#   plt.legend([r'$\\theta_0$',r'$\\theta_1$'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tD3PUWBVQLX-"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.Dataset.from_tensor_slices(thetas)\n",
        "data_std = thetas.std()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D9LK0AFrd8-J"
      },
      "outputs": [],
      "source": [
        "jnp.sqrt(((thetas[None,:400]-thetas[:400,None])**2).sum((-1,-2))).max()/jnp.sqrt(np.prod(thetas.shape[1:]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "brLISQLhQVXz"
      },
      "outputs": [],
      "source": [
        "bs = 400\n",
        "dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8SJ3stKFsOR1"
      },
      "outputs": [],
      "source": [
        "from matplotlib import rc\n",
        "rc('animation', html='jshtml')\n",
        "#ds.animate()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y1hHzbmWZrvS"
      },
      "source": [
        "##Diffusion"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vXD6VL3jTfpO"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import jax.numpy as jnp\n",
        "from jax import random\n",
        "import jax\n",
        "import flax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SMN8fY2UthvT"
      },
      "outputs": [],
      "source": [
        "x = next(dataiter())\n",
        "t = np.random.rand(x.shape[0])\n",
        "model = diffusion_unet.UNet(diffusion_unet.unet_64_config(out_dim=x.shape[-1],base_channels=24))\n",
        "params = model.init(random.PRNGKey(42), x=x,t=t,train=False)\n",
        "x.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pajn72Q-uOps"
      },
      "outputs": [],
      "source": [
        "def count_params(params):\n",
        "  if isinstance(params, jax.numpy.ndarray):\n",
        "    return np.prod(params.shape)\n",
        "  elif isinstance(params,(dict,flax.core.frozen_dict.FrozenDict)):\n",
        "    return sum([count_params(v) for v in params.values()])\n",
        "  else:\n",
        "    assert False, type(params)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pkRzOnhCuQE7"
      },
      "outputs": [],
      "source": [
        "count_params(params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3HyQzRHSr_qh"
      },
      "source": [
        "Initialize the UNet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZpPYuYEluyJo"
      },
      "outputs": [],
      "source": [
        "x.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U0zOA0KXvdbJ"
      },
      "outputs": [],
      "source": [
        "from tqdm.auto import tqdm\n",
        "import optax\n",
        "from jax import jit\n",
        "import pandas as pd\n",
        "importlib.reload(samplers)\n",
        "\n",
        "#sigma_min = 1e-3#2e-4#2e-3\n",
        "#sigma_max = 1#100\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZFhSBePQuzjs"
      },
      "outputs": [],
      "source": [
        "key = random.PRNGKey(38)\n",
        "with Mesh(mesh_utils.create_device_mesh((device_count(),)), ('data',)):\n",
        "  for epoch in tqdm(range(601)):\n",
        "    for data in dataiter():\n",
        "      params,ema_params,opt_state,key,loss_val = update_fn(params,ema_params,opt_state,key,data)\n",
        "    if epoch % 5 == 0:\n",
        "      ema_loss = jloss(ema_params,data,key)\n",
        "      message = f'Loss epoch {epoch}: {loss_val:.3f} Ema {ema_loss:.3f}'\n",
        "      # if not epoch % 30:\n",
        "      #   val = pmetric(samplers.stochastic_sampler(denoiser,params,key,(512,)+data.shape[1:],500)[0])[0]\n",
        "      #   #message += f'     Precision: {}'\n",
        "      print(message)\n",
        "    if epoch %200 ==0:\n",
        "      print(eval_metrics(dataiter,ema_params,key))\n",
        "\n",
        "params=ema_params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gCU1NmjKqLM-"
      },
      "outputs": [],
      "source": [
        "mb = data[:30]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z4v7WbBbQsPL"
      },
      "outputs": [],
      "source": [
        "importlib.reload(samplers)\n",
        "denoiser = jit(lambda params,x,sigma: denoised(params,x,jnp.ones(x.shape[0])*sigma,train=False))  \n",
        "def conditioning_scores(observed_values,s=.2):\n",
        "  b,n1,c = observed_values.shape\n",
        "  return jax.grad(lambda x: -jnp.sum((x.reshape(b,-1,c)[:,:n1]-observed_values)**2)/(2*s**2))\n",
        "#conditioning_scores(mb[:,:20]),\n",
        "\n",
        "  \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iqgGfIaO-POu"
      },
      "outputs": [],
      "source": [
        "importlib.reload(samplers)\n",
        "t=.001\n",
        "z = samplers.sample(denoiser,params,key,mb.shape,t,t_max)#,conditioning_scores(mb[:,:50]))\n",
        "noised_x = mb*samplers.s(t)+np.random.randn(*mb.shape)*(samplers.s(t)*samplers.sigma(t))\n",
        "import matplotlib.pyplot as plt\n",
        "i=2\n",
        "plt.plot(ds.T_long,mb[i,:,0])\n",
        "plt.plot(ds.T_long,noised_x[i,:,0])\n",
        "plt.plot(ds.T_long,z[i,:,0])\n",
        "\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'GT','GT noised xt',r'Model xt'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OUsObb_pFBZh"
      },
      "outputs": [],
      "source": [
        "importlib.reload(samplers)\n",
        "nll = samplers.compute_nll(denoiser,params,key,data[:400])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mDbkYQtcFGwa"
      },
      "outputs": [],
      "source": [
        "nll.mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kktv27hE2jZN"
      },
      "outputs": [],
      "source": [
        "importlib.reload(samplers)\n",
        "from jax import grad\n",
        "def score(x,t):\n",
        "  return (denoiser(params,x.reshape(mb.shape)/samplers.s(t),samplers.sigma(t)).reshape(-1)-x/samplers.s(t))/(samplers.s(t)*samplers.sigma(t)**2)\n",
        "dynamics = lambda t,x: grad(samplers.s)(t)*x/samplers.s(t)-(samplers.s(t)**2)*grad(samplers.sigma)(t)*score(x,t).reshape(-1)*samplers.sigma(t)\n",
        "dynamics2 = lambda t,x: (grad(samplers.s)(t)/samplers.s(t)+grad(samplers.sigma)(t)/samplers.sigma(t))*x - (grad(samplers.sigma)(t)/samplers.sigma(t))*samplers.s(t)*denoiser(params,x.reshape(mb.shape)/samplers.s(t),samplers.sigma(t)).reshape(-1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rsdaFI5jDm36"
      },
      "outputs": [],
      "source": [
        "dynamics(.99,xt.reshape(-1))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pvgo3RXfDrw0"
      },
      "outputs": [],
      "source": [
        "dynamics2(.99,xt.reshape(-1))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CYvRKZEq2uHn"
      },
      "outputs": [],
      "source": [
        "xt = np.random.randn(*mb.shape)*samplers.s(t_max)*samplers.sigma(t_max)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hwq8vbg93nBw"
      },
      "outputs": [],
      "source": [
        "xt.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tt-V5k-991-h"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A8649rV_40u5"
      },
      "outputs": [],
      "source": [
        "jnp.max(jnp.abs(samplers.score(denoiser,params,mb.shape)(mb.reshape(-1),1.)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WlBiNZIC5EP6"
      },
      "outputs": [],
      "source": [
        "t=1.\n",
        "denoiser(params,mb/samplers.s(t),samplers.sigma(t)).reshape(-1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E-XlJWv15M8-"
      },
      "outputs": [],
      "source": [
        "denoiser(params,mb/samplers.s(t),t)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JVVhWgkO5bBD"
      },
      "outputs": [],
      "source": [
        "denoised(params,mb,jnp.ones(mb.shape[0])*samplers.sigma(t),train=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sbReawQH3Gty"
      },
      "outputs": [],
      "source": [
        "dynamics(1.,mb.reshape(-1))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UnGckHT03PtR"
      },
      "outputs": [],
      "source": [
        "1/samplers.s(t_max)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lyXySA2M_0g6"
      },
      "outputs": [],
      "source": [
        "# z = jax.random.normal(key,(64,)+input_data.shape[1:])\n",
        "# y = denoiser(z,.1)\n",
        "# import numpy as np\n",
        "# perm = np.random.permutation(z.shape[0])\n",
        "# y2 = denoiser(z[perm],.1)[np.argsort(perm)]\n",
        "# print(jnp.linalg.norm(y-y2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ioj3N1vbRHdR"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "i=5\n",
        "plt.plot(ds.T_long,mb[i,:,0])\n",
        "plt.plot(ds.T_long,z[i,:,0])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'GT',r'Model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LhlxPaa5ZijM"
      },
      "outputs": [],
      "source": [
        "data = next(dataiter())\n",
        "key = random.PRNGKey(26)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EE5N_rBEUDiM"
      },
      "outputs": [],
      "source": [
        "nll = samplers.compute_nll(denoiser,params,key,data[:400],smin=sigma_min,smax=sigma_max,num_probes=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9LAyp0c2678J"
      },
      "outputs": [],
      "source": [
        "nll.mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q_0MJE3qdlo7"
      },
      "outputs": [],
      "source": [
        "nll.std(0)/jnp.sqrt(len(nll))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gR_lt_xWI0AW"
      },
      "outputs": [],
      "source": [
        "\n",
        "noised_data = samplers.forward_process2(denoiser,params,key,data,smin=sigma_min,smax=sigma_max)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gzv1gVO7NWOD"
      },
      "outputs": [],
      "source": [
        "noised_data.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bUaxGULjTe2I"
      },
      "outputs": [],
      "source": [
        "noised_data.std()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DblhkquiJK5n"
      },
      "outputs": [],
      "source": [
        "noised_data.std()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qO0uH4NmJ4zt"
      },
      "outputs": [],
      "source": [
        "T = samplers.timesteps(30,sigma_min,sigma_max)\n",
        "print(np.sum(T[1:]-T[:-1]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2mtArA_7vtyS"
      },
      "outputs": [],
      "source": [
        "\n",
        "key = random.PRNGKey(45)\n",
        "\n",
        "s,history = samplers.stochastic_sampler(denoiser,params,key,(32,)+data.shape[1:],N=1000,smin=sigma_min,smax=sigma_max)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aIModqvIW-xg"
      },
      "outputs": [],
      "source": [
        "#stochastic_sampler(params,key,(128,)+input_data.shape[1:],N=2000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fEU1jVUmDPlb"
      },
      "outputs": [],
      "source": [
        "s = samplers.sample(denoiser,params,random.split(key)[0],(64,)+data.shape[1:])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GJX9cnJ3lcGJ"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aE452iDGuwxc"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "plt.plot(ds.T_long,thetas[2,:,0])\n",
        "plt.plot(ds.T_long,thetas[2,:,-1])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'$\\theta_0$',r'$\\theta_1$'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yOm-f1H4uxok"
      },
      "outputs": [],
      "source": [
        "for i,h in enumerate(history[::200]):\n",
        "  plt.plot(ds.T_long,h[1,:,-1],label=str(i),alpha=1/3)\n",
        "plt.plot(ds.T_long,s[1,:,-1],label=str(i))\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend()\n",
        "plt.ylim((-3,3))\n",
        "#plt.legend([r'$\\theta_0$',r'$\\theta_1$'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PJ0ojxLm2VuD"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "from ipywidgets import interact\n",
        "\n",
        "\n",
        "\n",
        "# @interact(i=(0,s.shape[0]-1))\n",
        "# def plot(i=1):\n",
        "#   fig = plt.figure()\n",
        "#   ax = fig.add_subplot(1, 1, 1)\n",
        "#   line1, = ax.plot(ds.T_long,s[i,:,0])\n",
        "#   line2, = ax.plot(ds.T_long,s[i,:,1])\n",
        "#   plt.xlabel('Time t')\n",
        "#   plt.ylabel(r'State')\n",
        "#   plt.legend([r'$\\theta_0$',r'$\\theta_1$'])\n",
        "  #plt.ylim(-2,2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uWsm4HAMFccX"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K7PSODHy8abI"
      },
      "outputs": [],
      "source": [
        "for i in range(2):\n",
        "  fig = plt.figure()\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "  line1, = ax.plot(ds.T_long,s[i,:,0])\n",
        "  line2, = ax.plot(ds.T_long,s[i,:,-1])\n",
        "  plt.xlabel('Time t')\n",
        "  plt.ylabel(r'State')\n",
        "  plt.legend([r'$\\theta_0$',r'$\\theta_1$'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HJ2dVfECLvj1"
      },
      "outputs": [],
      "source": [
        "for i in range(10):\n",
        "  fig = plt.figure()\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "  line1, = ax.plot(ds.T_long,s[i,:,0])\n",
        "  line2, = ax.plot(ds.T_long,s[i,:,-1])\n",
        "  plt.xlabel('Time t')\n",
        "  plt.ylabel(r'State')\n",
        "  plt.legend([r'$\\theta_0$',r'$\\theta_1$'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S9cmzf_ZHMrR"
      },
      "outputs": [],
      "source": [
        "key = random.PRNGKey(45)\n",
        "#s=s2#,history = samplers.stochastic_sampler(denoiser,params,key,(32,)+data.shape[1:],N=500,smin=sigma_min,smax=sigma_max)\n",
        "\n",
        "\n",
        "k = 5\n",
        "q = s[:,k:]\n",
        "v = -(q[:,:-2]-q[:,2:])/(2*(ds.T[1]-ds.T[0]))\n",
        "z = ode_datasets.pack(q[:,1:-1],(vmap(vmap(ds.M))(q[:,1:-1])@v[...,None]).squeeze(-1))\n",
        "T = ds.T_long[k+1:-1]\n",
        "z0 = z[:,0]\n",
        "z_gts = vmap(ds.integrate,(0,None),0)(z0,T)\n",
        "z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*np.random.randn(*z0.shape),T)\n",
        "z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JqbHZCCgoilB"
      },
      "outputs": [],
      "source": [
        "q.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GjgmAsigXkC_"
      },
      "outputs": [],
      "source": [
        "for i in range(10):\n",
        "  fig = plt.figure()\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "  line1, = ax.plot(T,z_gts[i,:,0])\n",
        "  line2, = ax.plot(T,z[i,:,0])\n",
        "  line3, = ax.plot(T,z_pert[i,:,0])\n",
        "  plt.xlabel('Time t')\n",
        "  plt.ylabel(r'State')\n",
        "  plt.legend(['gt','model','pert'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ujIbtIHyxMaw"
      },
      "outputs": [],
      "source": [
        "for i in range(10):\n",
        "  fig = plt.figure()\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "  line1, = ax.plot(T,z_gts[i,:,0])\n",
        "  line2, = ax.plot(T,z[i,:,0])\n",
        "  line3, = ax.plot(T,z_gts[i,:,-1])\n",
        "  line5, = ax.plot(T,z[i,:,-1])\n",
        "  plt.xlabel('Time t')\n",
        "  plt.ylabel(r'State')\n",
        "  plt.legend([r'$\\theta_0$ gt',r'$\\theta_0$ model',r'v gt', r'v model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HjnKwVZdORrg"
      },
      "outputs": [],
      "source": [
        "pmetric(s)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "divghYKhaDhw"
      },
      "outputs": [],
      "source": [
        "for pred in [z,z_pert,z_random]:\n",
        "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(T,rel_errs)\n",
        "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.plot()\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WDm9CtIbTEIs"
      },
      "outputs": [],
      "source": [
        "for pred in [z,z_pert,z_random]:\n",
        "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(T,rel_errs)\n",
        "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.plot()\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UM_G0V6QMnjL"
      },
      "outputs": [],
      "source": [
        "H_gts = vmap(vmap(ds.H))(z_gts)\n",
        "for pred in [z,z_pert,z_random]:\n",
        "  Hs = vmap(vmap(ds.H))(pred)\n",
        "  clamped_errs = jax.lax.clamp(1e-3,jnp.abs(Hs-H_gts)/jnp.abs(Hs*H_gts),np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(T,rel_errs)\n",
        "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.plot()\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Energy Error')\n",
        "plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Am0N9YRAFlxS"
      },
      "outputs": [],
      "source": [
        "\n",
        "for H in Hs:\n",
        "  plt.plot(ds.T_long[1:-1],jnp.abs(H-H[0]))\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Energy Error')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xjtnSnyqqDGD"
      },
      "outputs": [],
      "source": [
        "metric_vals =[]\n",
        "metric_stds = []\n",
        "Ns = [25,50,100,200,500,1000,2000]\n",
        "for N in Ns:\n",
        "  s,_ = samplers.stochastic_sampler(denoiser,params,key,(256,)+data.shape[1:],N=N,smin=sigma_min,smax=sigma_max)\n",
        "  mean,std = pmetric(s)\n",
        "  metric_vals.append(mean)\n",
        "  metric_stds.append(std)\n",
        "metric_vals = np.array(metric_vals)\n",
        "metric_stds = np.array(metric_stds)\n",
        "plt.plot(Ns,metric_vals)\n",
        "plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)\n",
        "plt.xlabel('Sampler steps')\n",
        "plt.ylabel('Pmetric value')\n",
        "plt.xscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vPNbx_4GZOI4"
      },
      "outputs": [],
      "source": [
        "plt.plot(Ns,metric_vals)\n",
        "plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)\n",
        "plt.xlabel('Sampler steps')\n",
        "plt.ylabel('Pmetric value')\n",
        "plt.xscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XKmrvn60ex2J"
      },
      "outputs": [],
      "source": [
        "data = next(dataiter())\n",
        "key = random.PRNGKey(26)\n",
        "noised_x,sigma = noise_input(data,key)\n",
        "weighting = (sigma**2+data_std**2)/(sigma*data_std)**2\n",
        "losses = jnp.mean(((denoised(ema_params,noised_x,sigma)-data)**2)*weighting[:,None,None],axis=(-1,-2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FBpfHonffl_H"
      },
      "outputs": [],
      "source": [
        "\n",
        "plt.scatter(sigma,losses)\n",
        "#plt.plot(np.sort(sigma),jax.scipy.stats.norm.pdf(np.log(np.sort(sigma)),mu,std),color='y')\n",
        "#plt.hline(1e-1)\n",
        "#plt.scatter(sigma,weighting)\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "plt.ylabel('weighted loss')\n",
        "plt.xlabel(r'$\\sigma$')\n",
        "plt.legend(['loss values','sigma sample pdf'][:1:-1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hRLWfpQ_29Ad"
      },
      "outputs": [],
      "source": [
        "x = np.random.randn(256)\n",
        "\n",
        "binomial = [np.array([1., 1.])/2]\n",
        "for _ in range(int(np.floor(np.log2(len(x))))):\n",
        "  sqr = np.convolve(binomial[-1],binomial[-1])\n",
        "  #binomial[-1] /= sqr[sqr.shape[0]//2+1]\n",
        "  #binomial.append(sqr/sqr.sum())\n",
        "  binomial.append(sqr/sqr[sqr.shape[0]//2+1])\n",
        "  \n",
        "binomial = [np.array([1.])]+binomial[:-1]\n",
        "def blur(z):\n",
        "  return jnp.convolve(binomial[-1],z,mode='same')\n",
        "\n",
        "#vblur = vmap(vmap(blur,0,0),2,2)\n",
        "\n",
        "def vblur(z):\n",
        "  s = jnp.fft.rfft(z,axis=1)\n",
        "  f = 1+jnp.abs(jnp.fft.fftfreq(z.shape[1])[:s.shape[1]])*s.shape[1]\n",
        "  scaled = s/f[None,:,None]**.5\n",
        "  scaled = scaled/jnp.mean(jnp.abs(scaled),axis=1,keepdims=True)\n",
        "  noise = jnp.fft.irfft(scaled,axis=1)\n",
        "  return noise\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qrka-iN7h5Ke"
      },
      "outputs": [],
      "source": [
        "x = np.random.randn(300)\n",
        "\n",
        "binomial = [np.array([1., 1.])/2]\n",
        "for _ in range(int(np.floor(np.log2(len(x))))):\n",
        "  sqr = np.convolve(binomial[-1],binomial[-1])\n",
        "  #binomial[-1] /= sqr[sqr.shape[0]//2+1]\n",
        "  binomial.append(sqr/sqr.sum())\n",
        "  #binomial.append(sqr/sqr[sqr.shape[0]//2+1])\n",
        "  \n",
        "binomial = [np.array([1.])]+binomial[:-1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZLaW91Crq9l8"
      },
      "outputs": [],
      "source": [
        "blurred = [jax.scipy.signal.convolve(x,bin,mode='same') for bin in binomial]\n",
        "blurred.append(jnp.cumsum(x)/np.sqrt(len(x)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pmeV_A3dsrIv"
      },
      "outputs": [],
      "source": [
        "for i,bx in enumerate(blurred):\n",
        "  plt.plot(bx,label=str(2**i))\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z5C-esHKsuDN"
      },
      "outputs": [],
      "source": [
        "freq = np.fft.fftfreq(x.shape[0])[:x.shape[0]//2]*x.shape[0]\n",
        "for i,bx in enumerate(blurred):\n",
        "  plt.plot(freq, jnp.abs(np.fft.fft(bx)[:x.shape[0]//2]),label=str(2**i))\n",
        "\n",
        "plt.plot(freq,1/freq**2,label='brown')\n",
        "plt.plot(freq,1/freq,label='pink')\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d9o2ahEjyG55"
      },
      "outputs": [],
      "source": [
        "#plt.plot(freq, jnp.abs(np.fft.fft(vblur(x[None,:,None])[0,:,0])[:x.shape[0]//2]),label=str(2**i))\n",
        "plt.plot(freq,1/freq**2,label='brown')\n",
        "plt.plot(freq,1/freq,label='pink')\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hbj5vWgi9veU"
      },
      "outputs": [],
      "source": [
        "plt.plot(vblur(x[None,:,None])[0,:,0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FJttK8iqs8Td"
      },
      "outputs": [],
      "source": [
        "[(bx**2).mean() for x in blurred]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Sj1izTwsxjw5"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
        "kind": "private"
      },
      "name": "double_pendulum_diffusion.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
